import gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import deque
import pickle
import time
import threading
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import deque
import random
from google.colab import drive
# Mount Google Drive so figures can be written under /content/drive (Colab-only;
# this runs at import time and needs the google.colab import above).
drive.mount('/content/drive')
import os

# Neural Network (QNetwork)
class QNetwork(nn.Module):
    """Three-layer MLP mapping an observation vector to one Q-value per action.

    The layer attribute names ``fc1``/``fc2``/``fc3`` become ``state_dict`` keys,
    and state dicts are pickled and reloaded elsewhere in this file, so they are
    kept stable.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_units = 128
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_dim)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Replay Buffer
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity):
        self.capacity = capacity
        # deque(maxlen=...) drops the oldest transition once the buffer is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, transition):
        """Append one transition tuple."""
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns a 5-tuple of arrays (states, actions, rewards, next_states,
        dones), each stacked along a new leading batch axis.
        """
        minibatch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)
        return tuple(np.stack(column) for column in (states, actions, rewards, next_states, dones))

    def __len__(self):
        return len(self.buffer)

# Update the model
def update_model(q_online, q_target, optimizer, batch, gamma):
    """Do one TD(0) gradient step on ``q_online`` from a replay minibatch.

    The target is ``r + gamma * max_a' Q_target(s', a') * (1 - done)``, computed
    per sample. All per-sample tensors are kept 1-D of length B. Previously the
    (B, 1) max was mixed with (B,) rewards, which broadcast the target to
    (B, B) and compared every sample's Q with every other sample's target.

    Args:
        q_online: Network being trained; called on a (B, obs_dim) float tensor.
        q_target: Network used for bootstrapped targets (may be ``q_online``
            itself); evaluated without gradient tracking.
        optimizer: Optimizer over ``q_online``'s parameters.
        batch: ``(states, actions, rewards, next_states, dones)`` array-likes,
            as returned by ``ReplayBuffer.sample``.
        gamma: Discount factor.

    Returns:
        The scalar MSE loss (float) computed before the optimizer step.
    """
    state, action, reward, next_state, done = batch

    states = torch.as_tensor(np.asarray(state), dtype=torch.float32)
    actions = torch.as_tensor(np.asarray(action), dtype=torch.long)
    rewards = torch.as_tensor(np.asarray(reward), dtype=torch.float32)
    next_states = torch.as_tensor(np.asarray(next_state), dtype=torch.float32)
    dones = torch.as_tensor(np.asarray(done), dtype=torch.float32)

    # Q(s, a) for the action actually taken -> shape (B,)
    q_value = q_online(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # Shape (B,); no gradient flows into the target network.
        next_q_value = q_target(next_states).max(dim=1)[0]
        target = rewards + gamma * next_q_value * (1.0 - dones)

    loss = nn.MSELoss()(q_value, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Epsilon-Greedy Action Selection
def epsilon_greedy(q_network, state, epsilon, action_space):
    """Return a random action with probability ``epsilon``, else the greedy one.

    Args:
        q_network: Callable mapping a float tensor of ``state`` to Q-values.
        state: Observation convertible by ``torch.FloatTensor``.
        epsilon: Exploration probability.
        action_space: Sequence of actions passed to ``np.random.choice``.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return np.random.choice(action_space)
    with torch.no_grad():
        scores = q_network(torch.FloatTensor(state)).cpu().numpy()
    return np.argmax(scores)

# Simulate Data Rate (used for encrypted transmission simulation)
def simulate_data_rate(action='send', size=1024, link_bytes_per_sec=1e6):
    """Simulate transmitting ``size`` bytes and return the achieved rate in Mbps.

    Args:
        action: Only ``'send'`` incurs the simulated transmission delay.
        size: Payload size in bytes.
        link_bytes_per_sec: Simulated link speed in bytes/second (default 1 MBps,
            matching the previous hard-coded value).

    Returns:
        ``size * 8 / elapsed_seconds / 1e6``. This is ``inf`` if the measured
        elapsed time is zero, instead of raising ZeroDivisionError.
    """
    # perf_counter is monotonic and high-resolution; time.time() can step
    # backwards or have coarse resolution on some platforms.
    start_time = time.perf_counter()

    if action == 'send':
        time.sleep(size / link_bytes_per_sec)  # Simulated transmission delay

    elapsed_time = time.perf_counter() - start_time
    if elapsed_time <= 0:
        return float('inf')
    return (size * 8) / (elapsed_time * 1e6)  # in Mbps

# Encryption and Decryption (for Federated Learning with Security)
def generate_dh_parameters():
    """Generate 2048-bit Diffie-Hellman group parameters with generator 2.

    Generating parameters of this size can be slow; generate once and reuse.
    """
    return dh.generate_parameters(generator=2, key_size=2048, backend=default_backend())

def generate_dh_keypair(parameters):
    """Create a fresh private key from ``parameters``; return ``(private, public)``."""
    private = parameters.generate_private_key()
    return private, private.public_key()

def derive_shared_key(private_key, peer_public_key):
    """Compute the raw Diffie-Hellman shared secret with a peer.

    Args:
        private_key: Our DH private key.
        peer_public_key: The peer's public key as PEM-encoded bytes.

    Returns:
        Raw shared-secret bytes. Pass them through a KDF before using them as a key.
    """
    peer = serialization.load_pem_public_key(peer_public_key, backend=default_backend())
    secret = private_key.exchange(peer)
    return secret

_GCM_NONCE_LEN = 12  # bytes; the recommended nonce size for AES-GCM
_GCM_TAG_LEN = 16    # bytes; full-length authentication tag


def _derive_aes_key(shared_key):
    """Stretch a raw DH shared secret into a 256-bit AES key with PBKDF2-HMAC-SHA256.

    Both peers must derive the same key, so the salt is a fixed constant shared
    by ``encrypt_message`` and ``decrypt_message``.
    """
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=b'salt',
        iterations=100000,
        backend=default_backend()
    )
    return kdf.derive(shared_key)


def encrypt_message(shared_key, message):
    """Encrypt ``message`` with AES-256-GCM under a key derived from ``shared_key``.

    GCM is used instead of ECB with space padding for three reasons: it is
    authenticated, it does not reveal repeated plaintext blocks, and it needs
    no padding, so arbitrary bytes round-trip exactly.

    Args:
        shared_key: Raw DH shared secret (bytes).
        message: Plaintext bytes.

    Returns:
        ``nonce (12 bytes) || tag (16 bytes) || ciphertext``.
    """
    key = _derive_aes_key(shared_key)
    nonce = os.urandom(_GCM_NONCE_LEN)  # fresh random nonce per message
    encryptor = Cipher(algorithms.AES(key), modes.GCM(nonce), backend=default_backend()).encryptor()
    ciphertext = encryptor.update(message) + encryptor.finalize()
    return nonce + encryptor.tag + ciphertext


def decrypt_message(shared_key, encrypted_message):
    """Verify and decrypt a blob produced by ``encrypt_message``.

    Args:
        shared_key: Raw DH shared secret (bytes), the same one the sender used.
        encrypted_message: ``nonce || tag || ciphertext`` bytes.

    Returns:
        The exact original plaintext bytes (nothing is stripped).

    Raises:
        ValueError: if ``encrypted_message`` is too short to hold a nonce and tag.
        cryptography.exceptions.InvalidTag: if the key is wrong or the data was
            modified in transit.
    """
    header_len = _GCM_NONCE_LEN + _GCM_TAG_LEN
    if len(encrypted_message) < header_len:
        raise ValueError("encrypted message is too short")
    nonce = encrypted_message[:_GCM_NONCE_LEN]
    tag = encrypted_message[_GCM_NONCE_LEN:header_len]
    ciphertext = encrypted_message[header_len:]

    key = _derive_aes_key(shared_key)
    decryptor = Cipher(algorithms.AES(key), modes.GCM(nonce, tag), backend=default_backend()).decryptor()
    return decryptor.update(ciphertext) + decryptor.finalize()

# Federated DDQN with/without Security
def federated_learning(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, secure=False):
    """Train one DQN agent and simulate shipping its weights after every update.

    NOTE(review): a second ``federated_learning`` is defined later in this file.
    Python binds the last definition, so this version is shadowed and is not the
    one callers get at runtime.

    Args:
        env_name: Gym environment id passed to ``gym.make``.
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size. Updates start once the buffer holds
            more than this many transitions.
        gamma: Discount factor forwarded to ``update_model``.
        epsilon_start, epsilon_end: Initial and floor exploration rates.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        secure: If True, after every update the weights are pickled, encrypted
            with a DH-derived key, "sent", decrypted and reloaded.

    Returns:
        (avg_scores, mean_data_rate):
        - avg_scores: the 100-episode moving average of the score, one value
          per episode.
        - mean_data_rate: the mean simulated data rate in Mbps over all
          transmissions. It is NaN if no update ever ran, because ``np.mean``
          of an empty list is NaN.
    """
    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    q_online = QNetwork(input_dim, output_dim)
    q_target = QNetwork(input_dim, output_dim)
    q_target.load_state_dict(q_online.state_dict())

    optimizer = optim.Adam(q_online.parameters())
    replay_buffer = ReplayBuffer(capacity=10000)

    epsilon = epsilon_start
    epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

    scores = []
    avg_scores = []
    data_rates = []  # Track data rate

    # Set up Diffie-Hellman parameters and keys
    if secure:
        parameters = generate_dh_parameters()
        server_private_key, server_public_key = generate_dh_keypair(parameters)
        client_private_key, client_public_key = generate_dh_keypair(parameters)
        server_public_key_bytes = server_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        client_public_key_bytes = client_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )
        # Both endpoints live in this process. Each side derives the secret from
        # the other's public key, so the two values should be equal.
        shared_key_server = derive_shared_key(server_private_key, client_public_key_bytes)
        shared_key_client = derive_shared_key(client_private_key, server_public_key_bytes)
        print("DH key exchange complete.")

    for episode in range(num_episodes):
        # NOTE(review): assumes the classic gym API (reset() -> obs,
        # step() -> 4-tuple); confirm against the installed gym version.
        state = env.reset()
        done = False
        score = 0

        while not done:
            action = epsilon_greedy(q_online, state, epsilon, range(output_dim))
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push((state, action, reward, next_state, done))
            score += reward

            if len(replay_buffer) > batch_size:
                batch = replay_buffer.sample(batch_size)
                update_model(q_online, q_target, optimizer, batch, gamma)

                if secure:
                    # Encrypt model parameters (weights) before sending
                    model_params = pickle.dumps(q_online.state_dict())
                    encrypted_params = encrypt_message(shared_key_server, model_params)
                    # Simulate sending encrypted parameters
                    data_rate = simulate_data_rate(action='send', size=len(encrypted_params))
                    data_rates.append(data_rate)
                    # Decrypt model parameters on the receiving end
                    # The decrypted weights are q_online's own, so reloading them
                    # is a round trip: it exercises the pipeline but aggregates
                    # nothing. pickle.loads is acceptable only because these bytes
                    # were produced in-process; never unpickle data from a real peer.
                    decrypted_params = decrypt_message(shared_key_client, encrypted_params)
                    q_online.load_state_dict(pickle.loads(decrypted_params))
                else:
                    # Simulate sending model parameters (weights) without encryption
                    model_params = pickle.dumps(q_online.state_dict())
                    data_rate = simulate_data_rate(action='send', size=len(model_params))
                    data_rates.append(data_rate)

            state = next_state

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)

        if episode % 100 == 0:
            print(f"Episode {episode}, Average Score: {avg_score:.2f}, Epsilon: {epsilon:.2f}")

        epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

        # Update target network periodically
        # NOTE(review): with num_episodes <= 1000 this fires only after episode 0.
        if episode % 1000 == 0:
            q_target.load_state_dict(q_online.state_dict())

    env.close()
    return avg_scores, np.mean(data_rates)

# Distributed DDQN with malicious agents
def distributed_ddqn(env_name, num_agents, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, malicious_agents=None):
    """Train ``num_agents`` independent DQN agents, each in its own environment.

    Each agent has its own online/target networks, optimizer and replay buffer.
    Agents listed in ``malicious_agents`` have their online network reset to
    their target network after every environment tick. This is a simple
    simulation of a poisoned (stale) model update.

    Args:
        env_name: Gym environment id passed to ``gym.make`` (one env per agent).
        num_agents: Number of agents; must be at least 1.
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size. Updates start once an agent's buffer
            holds more than this many transitions.
        gamma: Discount factor forwarded to ``update_model``.
        epsilon_start, epsilon_end: Initial and floor exploration rates.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        malicious_agents: Iterable of agent indices to poison. Indices outside
            ``[0, num_agents)`` are ignored. Defaults to no malicious agents.

    Returns:
        (avg_scores, mean_data_rate):
        - avg_scores: the 100-episode moving average of the agents' mean
          episode score, one value per episode.
        - mean_data_rate: the mean simulated data rate in Mbps (NaN if no
          update ever ran).

    Raises:
        ValueError: if ``num_agents`` is less than 1.
    """
    if num_agents < 1:
        raise ValueError("num_agents must be at least 1")
    # None default instead of a shared mutable [] default.
    malicious_agents = [] if malicious_agents is None else list(malicious_agents)

    # One environment per agent. With a single shared env, every agent would
    # reset and step the same simulator, so agents would observe each other's
    # moves and step it past episode termination.
    envs = [gym.make(env_name) for _ in range(num_agents)]
    try:
        input_dim = envs[0].observation_space.shape[0]
        output_dim = envs[0].action_space.n

        agent_models = [QNetwork(input_dim, output_dim) for _ in range(num_agents)]
        target_models = [QNetwork(input_dim, output_dim) for _ in range(num_agents)]
        for target_model, agent_model in zip(target_models, agent_models):
            target_model.load_state_dict(agent_model.state_dict())

        optimizers = [optim.Adam(agent_model.parameters()) for agent_model in agent_models]
        replay_buffers = [ReplayBuffer(capacity=10000) for _ in range(num_agents)]

        epsilon = epsilon_start
        epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

        scores = []
        avg_scores = []
        data_rates = []  # Track data rate

        for episode in range(num_episodes):
            states = [env.reset() for env in envs]
            done = [False] * num_agents
            score = [0] * num_agents

            # Run until every agent's episode has ended; finished agents idle.
            while not all(done):
                for i in range(num_agents):
                    if done[i]:
                        continue
                    action = epsilon_greedy(agent_models[i], states[i], epsilon, range(output_dim))
                    next_state, reward, agent_done, _ = envs[i].step(action)
                    replay_buffers[i].push((states[i], action, reward, next_state, agent_done))
                    score[i] += reward
                    states[i] = next_state
                    done[i] = agent_done

                    if len(replay_buffers[i]) > batch_size:
                        batch = replay_buffers[i].sample(batch_size)
                        update_model(agent_models[i], target_models[i], optimizers[i], batch, gamma)

                        # Simulate sending model parameters (weights)
                        model_params = pickle.dumps(agent_models[i].state_dict())
                        data_rate = simulate_data_rate(action='send', size=len(model_params))
                        data_rates.append(data_rate)

                # Poisonous Attack: revert malicious agents' online model to their
                # target model. Negative indices are rejected so they cannot wrap
                # around and poison the wrong agent.
                for malicious_agent in malicious_agents:
                    if 0 <= malicious_agent < num_agents:
                        agent_models[malicious_agent].load_state_dict(target_models[malicious_agent].state_dict())

            scores.append(np.mean(score))
            avg_score = np.mean(scores[-100:])
            avg_scores.append(avg_score)

            if episode % 100 == 0:
                print(f"Episode {episode}, Average Score: {avg_score:.2f}, Epsilon: {epsilon:.2f}")

            epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

            # Update target network periodically
            if episode % 1000 == 0:
                for i in range(num_agents):
                    target_models[i].load_state_dict(agent_models[i].state_dict())
    finally:
        for env in envs:
            env.close()

    return avg_scores, np.mean(data_rates)

# Standalone DDQN
def standalone_dqn(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay):
    """Train a single DQN agent that has no separate target network.

    ``q_network`` is passed to ``update_model`` as both the online and the
    target network, so bootstrapped targets come from the network being trained.
    Weights are pickled after every update only to simulate a transmission and
    measure its data rate.

    Args:
        env_name: Gym environment id passed to ``gym.make``.
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size. Updates start once the buffer holds
            more than this many transitions.
        gamma: Discount factor forwarded to ``update_model``.
        epsilon_start, epsilon_end: Initial and floor exploration rates.
        epsilon_decay: Number of episodes over which epsilon decays linearly.

    Returns:
        (avg_scores, mean_data_rate):
        - avg_scores: the 100-episode moving average of the score, one value
          per episode.
        - mean_data_rate: the mean simulated data rate in Mbps (NaN if no
          update ever ran).
    """
    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    q_network = QNetwork(input_dim, output_dim)
    optimizer = optim.Adam(q_network.parameters())
    replay_buffer = ReplayBuffer(capacity=10000)

    epsilon = epsilon_start
    epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

    scores = []
    avg_scores = []
    data_rates = []

    for episode in range(num_episodes):
        # NOTE(review): assumes the classic gym API (reset() -> obs,
        # step() -> 4-tuple); confirm against the installed gym version.
        state = env.reset()
        done = False
        score = 0

        while not done:
            action = epsilon_greedy(q_network, state, epsilon, range(output_dim))
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push((state, action, reward, next_state, done))
            score += reward

            if len(replay_buffer) > batch_size:
                batch = replay_buffer.sample(batch_size)
                # Same network as online and target: no target-network stabilisation.
                update_model(q_network, q_network, optimizer, batch, gamma)
                # Simulate sending model parameters (weights)
                model_params = pickle.dumps(q_network.state_dict())
                data_rate = simulate_data_rate(action='send', size=len(model_params))
                data_rates.append(data_rate)

            state = next_state

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)

        if episode % 100 == 0:
            print(f"Episode {episode}, Average Score: {avg_score:.2f}, Epsilon: {epsilon:.2f}")

        epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

    env.close()
    return avg_scores, np.mean(data_rates)


import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from collections import deque

# QNetwork, ReplayBuffer, update_model and the other helpers used below are defined earlier in this file.

# Simulating RSSI (Received Signal Strength Indicator) values
def simulate_rssi(initial_rssi, decay_rate=0.01, noise_level=2.0):
    """Produce the next RSSI sample (dBm) from the previous one.

    The previous value is scaled by ``(1 - decay_rate)``, Gaussian noise with
    std-dev ``noise_level`` is added, and the result is floored at -120 dBm.

    NOTE(review): for negative dBm inputs, scaling by ``(1 - decay_rate)`` moves
    the value toward 0 dBm, i.e. a *stronger* signal. Confirm that this is the
    intended "decay".
    """
    decayed = initial_rssi * (1 - decay_rate)
    noisy = decayed + np.random.normal(0, noise_level)  # environmental fluctuation
    return max(noisy, -120)

# Federated DDQN with or without security (simulating the federated learning process)
def federated_learning(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, secure=False):
    """Train one DQN agent and record a simulated RSSI trace per episode.

    This later definition replaces the earlier ``federated_learning`` in this
    file, because the last ``def`` wins. Despite the name, no weights are
    exchanged or encrypted here. ``secure`` only adds a second, separately
    simulated RSSI series.

    Args:
        env_name: Gym environment id passed to ``gym.make``.
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size. Updates start once the buffer holds
            more than this many transitions.
        gamma: Discount factor forwarded to ``update_model``.
        epsilon_start, epsilon_end: Initial and floor exploration rates.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        secure: If True, also simulate an RSSI series starting at -55 dBm.

    Returns:
        (avg_scores, rssi_values):
        - avg_scores: the 100-episode moving average, one value per episode.
        - rssi_values: one entry per episode, or two per episode when
          ``secure`` is True (non-secure then secure, interleaved in one list).

    NOTE(review): ``simulate_rssi`` is redefined later in this file with no
    parameters. The last definition is the one looked up at call time, so the
    ``simulate_rssi(...)`` calls below raise TypeError unless that later
    definition accepts an argument. Confirm.
    """
    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    q_online = QNetwork(input_dim, output_dim)
    q_target = QNetwork(input_dim, output_dim)
    q_target.load_state_dict(q_online.state_dict())

    optimizer = optim.Adam(q_online.parameters())

    replay_buffer = ReplayBuffer(capacity=10000)

    epsilon = epsilon_start
    epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

    scores = []
    avg_scores = []
    rssi_values = []  # To track RSSI over episodes

    # Initializing RSSI for federated without security and with security
    rssi_current = -50  # Typical starting RSSI
    rssi_secure_current = -55  # Slightly worse RSSI for encrypted federated learning

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        score = 0

        while not done:
            action = epsilon_greedy(q_online, state, epsilon, range(output_dim))
            next_state, reward, done, _ = env.step(action)
            replay_buffer.push((state, action, reward, next_state, done))
            score += reward

            if len(replay_buffer) > batch_size:
                batch = replay_buffer.sample(batch_size)
                update_model(q_online, q_target, optimizer, batch, gamma)

            state = next_state

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)

        # Simulating RSSI value for this episode
        rssi_current = simulate_rssi(rssi_current)
        rssi_values.append(rssi_current)

        if secure:
            # NOTE(review): appended to the same list as the non-secure reading,
            # so the two series are interleaved rather than kept separate.
            rssi_secure_current = simulate_rssi(rssi_secure_current)
            rssi_values.append(rssi_secure_current)

        epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

        # Update target network periodically
        # NOTE(review): with num_episodes <= 1000 this fires only after episode 0.
        if episode % 1000 == 0:
            q_target.load_state_dict(q_online.state_dict())

    env.close()
    return avg_scores, rssi_values



# Federated DDQN with IP Security (superseded by a later definition of the same name in this file)
def federated_ddqn_with_security(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, malicious_ips=[], congestion_rate=0.05):
    """Simulate Federated DDQN with IP security, accepting or denying updates by source IP.

    NOTE(review): a later ``federated_ddqn_with_security`` in this file replaces
    this definition at runtime (the last ``def`` wins).

    Each episode is attributed to one IP drawn from ``malicious_ips`` plus two
    benign addresses. On each step:
    - A simulated packet drop or congestion event leaves the environment
      unstepped (state unchanged, reward -1 or -0.5).
    - Transitions from a denied IP are not stored; their reward becomes -2.
    Training still samples from whatever the buffer already holds.

    Args:
        env_name: Gym environment id passed to ``gym.make``.
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size. Updates start once the buffer holds
            more than this many transitions.
        gamma: Discount factor forwarded to ``update_model``.
        epsilon_start, epsilon_end: Initial and floor exploration rates.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        malicious_ips: IPs to deny. The default list is only read (concatenated
            into a new list), never mutated.
        congestion_rate: Per-step congestion probability passed to
            ``simulate_network_conditions``.

    Returns:
        (avg_scores, congestion_counts, packet_drops, rssi_values), each with
        one entry per episode. ``rssi_values`` holds the mean of that episode's
        per-step readings.
    """
    env = gym.make(env_name)
    input_dim = env.observation_space.shape[0]
    output_dim = env.action_space.n

    q_online = QNetwork(input_dim, output_dim)
    q_target = QNetwork(input_dim, output_dim)
    q_target.load_state_dict(q_online.state_dict())

    optimizer = optim.Adam(q_online.parameters())
    replay_buffer = ReplayBuffer(capacity=10000)

    epsilon = epsilon_start
    epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

    scores = []
    avg_scores = []
    congestion_counts = []  # To track congestion events
    packet_drops = []  # To track packet drops
    rssi_values = []  # To track RSSI

    # Simulating IP security for federated learning
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        score = 0
        congestion_events = 0
        packet_drop_events = 0
        rssi_readings = []
        ip_address = random.choice(malicious_ips + ['192.168.1.1', '192.168.1.2'])  # Simulating random IP addresses

        while not done:
            action = epsilon_greedy(q_online, state, epsilon, range(output_dim))
            # simulate_network_conditions and ip_security_check are defined later
            # in this file; they are resolved at call time.
            packet_drop, congestion, rssi = simulate_network_conditions(congestion_rate=congestion_rate)

            if packet_drop:
                packet_drop_events += 1
                reward = -1  # Penalize for packet drop
                next_state = state  # Keep the current state when packet is dropped
            elif congestion:
                congestion_events += 1
                reward = -0.5  # Penalize for congestion
                next_state = state  # Keep the current state when there is congestion
            else:
                next_state, reward, done, _ = env.step(action)  # Normal state transition

            rssi_readings.append(rssi)

            # Simulate IP Security check for malicious updates
            if ip_security_check(ip_address, malicious_ips):
                replay_buffer.push((state, action, reward, next_state, done))
            else:
                reward = -2  # Heavier penalty for denied updates from malicious agents

            score += reward
            state = next_state

            if len(replay_buffer) > batch_size:
                batch = replay_buffer.sample(batch_size)
                update_model(q_online, q_target, optimizer, batch, gamma)

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)

        congestion_counts.append(congestion_events)
        packet_drops.append(packet_drop_events)
        rssi_values.append(np.mean(rssi_readings))  # Averaging RSSI over the episode

        epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

        # Periodic update of target network
        if episode % 1000 == 0:
            q_target.load_state_dict(q_online.state_dict())

    env.close()
    return avg_scores, congestion_counts, packet_drops, rssi_values







def epsilon_greedy(q_network, state, epsilon, action_space):
    """Epsilon-greedy action selection.

    With probability ``epsilon`` a uniformly random element of ``action_space``
    is returned; otherwise the index of the largest Q-value that ``q_network``
    produces for ``state``.
    """
    if not np.random.rand() < epsilon:
        with torch.no_grad():
            q_row = q_network(torch.FloatTensor(state))
            return np.argmax(q_row.cpu().numpy())
    return np.random.choice(action_space)

def simulate_congestion(action='send', congestion_rate=0.1):
    """Return True with probability ``congestion_rate`` (``action`` is unused)."""
    draw = np.random.rand()
    return draw < congestion_rate

def simulate_packet_drop(probability=0.1):
    """Return True with the given ``probability`` (one uniform draw)."""
    draw = np.random.rand()
    return draw < probability

def simulate_rssi(initial_rssi=None, decay_rate=0.01, noise_level=2.0):
    """Simulate one RSSI sample in dBm.

    This definition replaces an earlier ``simulate_rssi(initial_rssi, ...)`` in
    this file, and both call styles are in use (``simulate_rssi()`` and
    ``simulate_rssi(previous_rssi)``). Supporting both keeps every caller working
    instead of raising TypeError.

    Args:
        initial_rssi: Previous RSSI value. If None, return an independent
            sample from N(-70, 5) dBm (the zero-argument behavior).
        decay_rate: Multiplicative factor ``(1 - decay_rate)`` applied to
            ``initial_rssi``. Only used when ``initial_rssi`` is given.
        noise_level: Std-dev of the Gaussian noise added. Only used when
            ``initial_rssi`` is given.

    Returns:
        The RSSI sample. In the previous-value mode it is floored at -120 dBm.

    NOTE(review): for negative dBm inputs, scaling by ``(1 - decay_rate)`` moves
    the value toward 0 dBm (a stronger signal). Confirm this is the intended
    "decay".
    """
    if initial_rssi is None:
        return np.random.normal(loc=-70, scale=5)  # Simulate RSSI values (in dBm)
    rssi = initial_rssi * (1 - decay_rate)
    rssi += np.random.normal(0, noise_level)
    return max(rssi, -120)

def verify_ip_security(malicious_ips, agent_ip):
    """Return True when ``agent_ip`` is NOT on the ``malicious_ips`` deny list."""
    blocked = agent_ip in malicious_ips
    return not blocked

def federated_ddqn_with_security(env_name, num_episodes, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, malicious_ips, congestion_rate):
    """Train a DQN agent while simulating network KPIs and an IP-based security check.

    Each episode is attributed to a source IP drawn uniformly from
    ``malicious_ips`` plus two benign addresses. If that IP fails
    ``verify_ip_security``, every reward in the episode is replaced by -10. The
    transition is still stored and trained on, which simulates a poisoned
    contribution. RSSI, packet-drop and congestion samples are recorded once
    per step but do not affect the environment transition.

    Args:
        env_name: Gym environment id passed to ``gym.make``.
        num_episodes: Number of training episodes.
        batch_size: Replay minibatch size. Updates start once the buffer holds
            more than this many transitions.
        gamma: Discount factor forwarded to ``update_model``.
        epsilon_start, epsilon_end: Initial and floor exploration rates.
        epsilon_decay: Number of episodes over which epsilon decays linearly.
        malicious_ips: Iterable of IP strings to deny (``None`` means none).
        congestion_rate: Per-step congestion probability.

    Returns:
        (avg_scores, rssi_readings, packet_drops, congestion_readings):
        - avg_scores: one 100-episode moving average per episode.
        - the other three lists: one sample per environment step.
    """
    malicious_ips = list(malicious_ips) if malicious_ips is not None else []
    benign_ips = ['192.168.1.1', '192.168.1.2']

    env = gym.make(env_name)
    try:
        input_dim = env.observation_space.shape[0]
        output_dim = env.action_space.n

        q_online = QNetwork(input_dim, output_dim)
        q_target = QNetwork(input_dim, output_dim)
        q_target.load_state_dict(q_online.state_dict())

        optimizer = optim.Adam(q_online.parameters())

        replay_buffer = ReplayBuffer(capacity=10000)

        epsilon = epsilon_start
        epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay

        scores = []
        avg_scores = []
        rssi_readings = []
        congestion_readings = []
        packet_drops = []

        for episode in range(num_episodes):
            state = env.reset()  # Ensure proper reset handling
            done = False
            score = 0

            # The security decision depends on where an update comes from (its
            # IP), not on the action taken, so draw the source once per episode
            # and check it against the caller-supplied deny list.
            agent_ip = random.choice(malicious_ips + benign_ips)
            is_trusted = verify_ip_security(malicious_ips, agent_ip)

            while not done:
                action = epsilon_greedy(q_online, state, epsilon, range(output_dim))
                next_state, reward, done, _ = env.step(action)

                # Simulate network KPIs
                rssi = simulate_rssi()
                rssi_readings.append(rssi)

                packet_drop = simulate_packet_drop(probability=0.1)
                packet_drops.append(packet_drop)

                congestion = simulate_congestion(action='send', congestion_rate=congestion_rate)
                congestion_readings.append(congestion)

                # Security check - penalize updates from malicious sources
                if not is_trusted:
                    reward = -10  # Penalize for malicious agents

                replay_buffer.push((state, action, reward, next_state, done))
                score += reward

                if len(replay_buffer) > batch_size:
                    batch = replay_buffer.sample(batch_size)
                    update_model(q_online, q_target, optimizer, batch, gamma)

                state = next_state  # Proceed to the next state

            scores.append(score)
            avg_score = np.mean(scores[-100:])
            avg_scores.append(avg_score)

            if episode % 100 == 0:
                print(f"Episode {episode}, Average Score: {avg_score:.2f}, Epsilon: {epsilon:.2f}")

            epsilon = max(epsilon - epsilon_decay_rate, epsilon_end)

            # Update target network periodically
            if episode % 1000 == 0:
                q_target.load_state_dict(q_online.state_dict())
    finally:
        env.close()

    return avg_scores, rssi_readings, packet_drops, congestion_readings




# Simulating network conditions (Packet drops, RSSI, Congestion)
def simulate_network_conditions(packet_drop_rate=0.1, congestion_rate=0.05):
    """Draw one step's worth of simulated network conditions.

    Returns:
        (packet_drop, congestion, rssi): two Bernoulli outcomes with the given
        probabilities, and an RSSI sample from N(-70, 10) dBm.
    """
    drop_draw = np.random.rand()
    congestion_draw = np.random.rand()
    signal_dbm = np.random.normal(-70, 10)
    return drop_draw < packet_drop_rate, congestion_draw < congestion_rate, signal_dbm
# NOTE(review): sympy is not used anywhere else in this file; this import and the
# version print look like a leftover environment check, and both run at import time.
import sympy
print(sympy.__version__)
# Simulating IP Security (Poisonous Attack)
def ip_security_check(ip_address, malicious_ips):
    """Decide whether an update from ``ip_address`` should be accepted.

    Returns True (accept) unless the address appears in ``malicious_ips``.
    """
    is_blocked = ip_address in malicious_ips
    return not is_blocked

def run_simulations():
    """Train the secure federated DDQN once, then report and plot its KPIs.

    Returns:
        The result tuple from ``federated_ddqn_with_security``.
    """
    # Define simulation parameters
    num_episodes = 100
    batch_size = 64
    gamma = 0.99
    epsilon_start = 1.0
    epsilon_end = 0.01
    epsilon_decay = 1000
    congestion_rate = 0.1  # Example congestion rate

    # Run Federated DDQN with security
    print("Running Federated DDQN with Security...")
    federated_with_security = federated_ddqn_with_security(
        env_name='CartPole-v1', num_episodes=num_episodes, batch_size=batch_size, gamma=gamma,
        epsilon_start=epsilon_start, epsilon_end=epsilon_end, epsilon_decay=epsilon_decay,
        malicious_ips=['malicious'], congestion_rate=congestion_rate
    )

    # Report and plot via simulate_results, which consumes this result tuple.
    # (No plot_results helper exists in this file; calling it raised NameError
    # only after the full training run had finished.)
    simulate_results(federated_with_security)
    return federated_with_security
# Example of saving to Google Drive
#save_dir = '/content/drive/My Drive/SEC'
#os.makedirs(save_dir, exist_ok=True)


# Define latency function
def calculate_latency(rounds, congestion_rate):
    """Simulate one latency value (ms) per element of ``rounds``.

    Each value is a 17 ms base scaled by ``(1 + congestion_rate)`` plus N(0, 5)
    jitter. Only ``len(rounds)`` is used.
    """
    base_latency_ms = 17
    jitter = np.random.normal(0, 5, len(rounds))
    congested = base_latency_ms * (1 + congestion_rate)
    return congested + jitter

# Define accuracy function
def calculate_accuracy(rewards):
    """Express each reward as a percentage of the best reward, clipped to [0, 100].

    Edge cases:
    - Empty input returns an empty array (``max([])`` would raise).
    - If the best reward is <= 0, all accuracies are 0. Dividing by a zero or
      negative maximum would give nan/inf or flip signs, so the *worst*
      episodes would clip to 100%.

    Args:
        rewards: Sequence of numeric rewards.

    Returns:
        Float ndarray of the same length as ``rewards``.
    """
    values = np.asarray(rewards, dtype=float)
    if values.size == 0:
        return values
    best = values.max()
    if best <= 0:
        return np.zeros_like(values)
    return np.clip(values / best * 100, 0, 100)





# Output directory for saved figures, created at import time. The bar-chart
# function below writes its PDF under this same path.
save_dir = '/content/drive/My Drive/SEC'
os.makedirs(save_dir, exist_ok=True)
# Updated plot function for barcharts
def plot_latency_and_accuracy_barchart(latency, accuracy, labels,
                                       save_path='/content/drive/My Drive/SEC/Latency_Accuracy_BarchartK1.pdf'):
    """Plot mean latency and mean accuracy per model as side-by-side bars.

    Args:
        latency: One latency series (ms) per model; each is averaged.
        accuracy: One accuracy series (%) per model; each is averaged.
        labels: One x-axis label per model.
        save_path: Where the PDF is written. Defaults to the previously
            hard-coded Google-Drive location. The parent directory is created
            if it is missing.
    """
    # Aggregate average latency and accuracy
    avg_latency = [np.mean(series) for series in latency]
    avg_accuracy = [np.mean(series) for series in accuracy]

    x = np.arange(len(labels))  # Label positions
    width = 0.4  # Bar width

    plt.figure(figsize=(10, 6))

    # Latency Bar Chart
    plt.bar(x - width/2, avg_latency, width, label='Latency (ms)', color='skyblue')

    # Accuracy Bar Chart
    plt.bar(x + width/2, avg_accuracy, width, label='Accuracy (%)', color='salmon', alpha=1)

    # Adding labels and formatting
    plt.xlabel('Models', fontsize=16)
    plt.ylabel('Values', fontsize=19)
    plt.xticks(x, labels, fontsize=18)
    plt.legend(fontsize=20, loc='upper center', bbox_to_anchor=(0.5, 1.15), ncol=2)
    plt.grid(linestyle='--', alpha=1)

    # Save plot as PDF
    plt.tight_layout()
    output_dir = os.path.dirname(save_path)
    if output_dir:  # dirname is '' for a bare filename
        os.makedirs(output_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.show()

# Simulated latency/accuracy results (only the FLDDQN w/ Sec model is currently included)
def simulate_results():
    """Plot latency/accuracy for every model whose rewards are available.

    Reads the module-level ``federated_with_security`` result tuple (assigned in
    the ``__main__`` block) instead of taking it as an argument. A later
    ``simulate_results(federated_with_security)`` definition in this file
    replaces this one at runtime.
    """
    rounds = range(1, 101)  # Simulate for 100 rounds
    labels = ["FLDDQN w/ Sec"]

    # Use rewards directly from provided functions
    federated_with_security_rewards = list(federated_with_security[0])
    rewards = [federated_with_security_rewards]

    # One congestion rate per model. Only iterate over the models we actually
    # have rewards for; range(4) over a 1-element list raised IndexError.
    congestion_rates = [0.1, 0.15, 0.2, 0.25]
    latency = [calculate_latency(rounds, congestion_rates[i]) for i in range(len(rewards))]
    accuracy = [calculate_accuracy(series) for series in rewards]

    # Plot results
    plot_latency_and_accuracy_barchart(latency, accuracy, labels)


# Number of rounds printed by simulate_results (defined just below); it is read
# there as a module-level global, not passed in.
num_episodes = 5
def simulate_results(federated_with_security):
    """Print per-round latency/accuracy and plot the per-model averages.

    Args:
        federated_with_security: Result tuple whose first element is the list of
            per-episode average scores (as returned by
            ``federated_ddqn_with_security``).

    The number of rounds printed is the module-level ``num_episodes``, capped at
    the number of available latency/accuracy samples.
    """
    rounds = range(1, 101)  # Simulate for 100 rounds
    labels = ["FLDDQN w/Sec"]

    # Use rewards directly from provided outputs
    rewards = [list(federated_with_security[0])]

    # One latency series per model actually present. Four series for one label
    # made matplotlib broadcast four overlapping latency bars onto one tick.
    congestion_rates = [0.1, 0.15, 0.2, 0.25]
    latency = [calculate_latency(rounds, congestion_rates[i]) for i in range(len(rewards))]
    accuracy = [calculate_accuracy(series) for series in rewards]

    print(f"labels:{labels}, Accuracy (%): {accuracy}")
    print(f"labels:{labels}, Latency (ms): {latency}")

    # Print each round's own values rather than the whole arrays every round.
    n_rounds = min(num_episodes, len(latency[0]), len(accuracy[0]))
    for i in range(n_rounds):
        print(f"Rounds {i+1}/{num_episodes}:")
        print(f"  FLDDQN w/ Sec -> Latency: {latency[0][i]:.2f}, Accuracy: {accuracy[0][i]:.2f}%")

    # Plot results
    plot_latency_and_accuracy_barchart(latency, accuracy, labels)
        

if __name__ == '__main__':
    # Train the active (last-defined) federated_ddqn_with_security once, then
    # report and plot latency/accuracy via simulate_results.
    # NOTE(review): a second `if __name__ == '__main__':` block near the end of
    # this file calls run_simulations(), which trains again.
    # Assuming these functions return results in the format expected
    federated_with_security = federated_ddqn_with_security(
        env_name='CartPole-v1', num_episodes=100, batch_size=64, gamma=0.99,
        epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=1000,
        malicious_ips=['malicious'], congestion_rate=0.1
    )

    

    
    

    # Run the simulations with the actual results
    simulate_results(federated_with_security)

























if __name__ == '__main__':
    # NOTE(review): this is the second __main__ guard in the file. When run as a
    # script, the earlier guard's training run executes first, and then this one
    # runs run_simulations(), which trains federated_ddqn_with_security again.
    # Run all simulations and plot the KPIs
    run_simulations()
